{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DDUNCwVslByF"
      },
      "source": [
        "# LoRA (Finetuning)\n",
        "\n",
        "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/gemma/blob/main/colabs/lora_finetuning.ipynb)\n",
        "\n",
        "This is an example of fine-tuning Gemma with LoRA. It builds on the [finetuning](https://github.com/google-deepmind/gemma/blob/main/docs/finetuning.md) colab, so it's best to read that one first.\n",
        "\n",
        "See the [LoRA sampling](https://github.com/google-deepmind/gemma/blob/main/docs/lora_sampling.md) tutorial if you just want to do inference with LoRA.\n"
      ]
    },
    {
      "metadata": {
        "id": "FgQUqR_cMZ2h"
      },
      "cell_type": "code",
      "source": [
        "!pip install -q gemma"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TZQiYQy7EJe3"
      },
      "outputs": [],
      "source": [
        "# Common imports\n",
        "import os\n",
        "import optax\n",
        "import treescope\n",
        "\n",
        "# Gemma imports\n",
        "from kauldron import kd\n",
        "from gemma import gm"
      ]
    },
    {
      "metadata": {
        "id": "vjp0xiR13fj6"
      },
      "cell_type": "markdown",
      "source": [
        "By default, JAX does not use the full GPU memory, but this can be overridden. See [GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html):"
      ]
    },
    {
      "metadata": {
        "id": "v4XpXYV13fj6"
      },
      "cell_type": "code",
      "source": [
        "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]=\"1.00\""
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5Yi0BzpgEMwS"
      },
      "source": [
        "\n",
        "## Config updates\n",
        "\n",
        "If you're familiar with the [finetuning](https://github.com/google-deepmind/gemma/blob/main/docs/finetuning.md) tutorial, switching to LoRA requires only three changes to the trainer.\n",
        "\n",
        "For an end-to-end example, see the\n",
        "[lora.py](https://github.com/google-deepmind/gemma/tree/main/examples/lora.py) config."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GV8_s4cTS7r8"
      },
      "source": [
        "### 1. Model\n",
        "\n",
        "Wrap the model in `gm.nn.LoRA`. This applies model surgery that replaces all the linear and compatible layers with LoRA layers."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "V1hVhUsNTDNs"
      },
      "outputs": [],
      "source": [
        "model = gm.nn.LoRA(\n",
        "    rank=4,\n",
        "    model=gm.nn.Gemma3_4B(tokens=\"batch.input\", text_only=True),\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bWpNYlosfA4p"
      },
      "source": [
        "Internally, this uses the [`gemma.peft`](https://github.com/google-deepmind/gemma/blob/main/gemma/peft) mini-library to perform model surgery."
      ]
    },
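    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the idea behind the surgery concrete, here is a minimal sketch of a single LoRA-augmented linear layer in plain JAX. It is not the `gemma.peft` implementation: the names (`toy_frozen_w`, `toy_lora_a`, `toy_lora_b`, `toy_lora_dense`), the layer sizes, and the zero initialization of the second factor are assumptions chosen for illustration. The frozen kernel is left untouched, and the product of two small rank-4 matrices is added to its output:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "# Hypothetical sizes for illustration only (not Gemma's real dimensions).\n",
        "toy_in_dim, toy_out_dim, toy_rank = 256, 256, 4\n",
        "\n",
        "key_w, key_a, key_x = jax.random.split(jax.random.PRNGKey(0), 3)\n",
        "\n",
        "# Frozen pretrained kernel: never updated during LoRA fine-tuning.\n",
        "toy_frozen_w = jax.random.normal(key_w, (toy_in_dim, toy_out_dim))\n",
        "# Low-rank factors: the only trainable weights.\n",
        "# Assumption: A is small and random, B is zero, so the layer starts unchanged.\n",
        "toy_lora_a = 0.01 * jax.random.normal(key_a, (toy_in_dim, toy_rank))\n",
        "toy_lora_b = jnp.zeros((toy_rank, toy_out_dim))\n",
        "\n",
        "\n",
        "def toy_lora_dense(x):\n",
        "  # Base output plus the low-rank update: x @ W + (x @ A) @ B.\n",
        "  return x @ toy_frozen_w + (x @ toy_lora_a) @ toy_lora_b\n",
        "\n",
        "\n",
        "toy_x = jax.random.normal(key_x, (2, toy_in_dim))\n",
        "toy_y = toy_lora_dense(toy_x)\n",
        "\n",
        "toy_full_count = toy_frozen_w.size  # 256 * 256 weights in the frozen kernel.\n",
        "toy_lora_count = toy_lora_a.size + toy_lora_b.size  # 256 * 4 + 4 * 256.\n",
        "print(toy_y.shape, toy_full_count, toy_lora_count)"
      ]
    },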
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IVgOieKuS9Ac"
      },
      "source": [
        "### 2. Checkpoint\n",
        "\n",
        "Wrap the init transform in `gm.ckpts.SkipLoRA`. The wrapper is required because the parameter structure differs with and without LoRA.\n",
        "\n",
        "Only the pretrained weights are loaded from the checkpoint; the LoRA weights keep their random initialization."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EIQ3yOhZe8Ep"
      },
      "outputs": [],
      "source": [
        "init_transform = gm.ckpts.SkipLoRA(\n",
        "    wrapped=gm.ckpts.LoadCheckpoint(\n",
        "        path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,\n",
        "    ),\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NjJLqi_Xgrhh"
      },
      "source": [
        "Note: If you're loading the weights directly with `gm.ckpts.load_params`, you can use `peft.split_params` and `peft.merge_params` instead. See [LoRA sampling](https://github.com/google-deepmind/gemma/blob/main/docs/lora_sampling.md) for an example."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M9uec_APS_oi"
      },
      "source": [
        "### 3. Optimizer\n",
        "\n",
        "Finally, we add a mask to the optimizer (with `kd.optim.partial_updates`), so only the LoRA weights are trained."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "g7HJYT4ae8nV"
      },
      "outputs": [],
      "source": [
        "optimizer = kd.optim.partial_updates(\n",
        "    optax.adafactor(learning_rate=0.005),\n",
        "    # We only optimize the LoRA weights. The rest of the model is frozen.\n",
        "    mask=kd.optim.select(\"lora\"),\n",
        ")"
      ]
    },
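    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For intuition about what masking the optimizer does, the sketch below freezes part of a toy parameter tree using only plain `optax`. It does not show how `kd.optim.partial_updates` is implemented; the parameter names, the labels, and the choice of `optax.multi_transform` with `optax.set_to_zero` are assumptions made for illustration. Only the entries labeled `trainable` receive updates:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import optax\n",
        "\n",
        "# Hypothetical parameter tree: one frozen kernel and two LoRA factors.\n",
        "toy_params = {\n",
        "    'kernel': jnp.ones((3, 3)),\n",
        "    'lora': {'a': jnp.ones((3, 2)), 'b': jnp.zeros((2, 3))},\n",
        "}\n",
        "# One label per leaf, selecting which transform handles it.\n",
        "toy_labels = {\n",
        "    'kernel': 'frozen',\n",
        "    'lora': {'a': 'trainable', 'b': 'trainable'},\n",
        "}\n",
        "\n",
        "toy_optimizer = optax.multi_transform(\n",
        "    {'trainable': optax.sgd(0.1), 'frozen': optax.set_to_zero()},\n",
        "    toy_labels,\n",
        ")\n",
        "toy_opt_state = toy_optimizer.init(toy_params)\n",
        "\n",
        "# Pretend every gradient is 1 so the effect of the mask is easy to read.\n",
        "toy_grads = jax.tree_util.tree_map(jnp.ones_like, toy_params)\n",
        "toy_updates, toy_opt_state = toy_optimizer.update(\n",
        "    toy_grads, toy_opt_state, toy_params\n",
        ")\n",
        "toy_new_params = optax.apply_updates(toy_params, toy_updates)\n",
        "\n",
        "# The kernel is unchanged; the LoRA factors moved by -0.1 * gradient.\n",
        "print(toy_new_params)"
      ]
    },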
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IY-2uskIj5z7"
      },
      "source": [
        "## Training"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2hWTB015lM0Z"
      },
      "source": [
        "### Data pipeline\n",
        "\n",
        "As in the [finetuning](https://github.com/google-deepmind/gemma/blob/main/docs/finetuning.md) example, we first recreate the tokenizer:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 537,
          "status": "ok",
          "timestamp": 1739286607692,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "bywIkAHklSlX",
        "outputId": "0b3c0d27-1707-4566-bb4c-473d21cb78df"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[\u003c_Gemma2SpecialTokens.BOS: 2\u003e, 1596, 603, 671, 3287, 13060]"
            ]
          },
          "execution_count": 5,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "tokenizer = gm.text.Gemma3Tokenizer()\n",
        "\n",
        "tokenizer.encode('This is an example sentence', add_bos=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2_J-Op0DlSNv"
      },
      "source": [
        "And the data pipeline:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "height": 211
        },
        "executionInfo": {
          "elapsed": 6724,
          "status": "ok",
          "timestamp": 1739286745151,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "qf3-uXF6n2e0",
        "outputId": "58e258d5-5dfb-465f-8a1f-11e1da507ffa"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Disabling pygrain multi-processing (unsupported in colab).\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "{\n",
              "  'input': # np.ndarray int64(8, 200) [≥0, ≤235_349] zero:1_052 nonzero:548\n",
              "    array([[   2,  106, 1645, ...,    0,    0,    0],\n",
              "           [   2,  106, 1645, ...,    0,    0,    0],\n",
              "           [   2,  106, 1645, ...,    0,    0,    0],\n",
              "           ...,\n",
              "           [   2,  106, 1645, ...,    0,    0,    0],\n",
              "           [   2,  106, 1645, ...,    0,    0,    0],\n",
              "           [   2,  106, 1645, ...,    0,    0,    0]], shape=(8, 200))\n",
              "  ,\n",
              "  'loss_mask': \u003cnp.ndarray bool(8, 200, 1) true:291 false:1_309\u003e,\n",
              "  'target': \u003cnp.ndarray int64(8, 200, 1) [≥0, ≤235_349] zero:1_052 nonzero:548\u003e,\n",
              "}"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "ds = kd.data.py.Tfds(\n",
        "    name='mtnt/en-fr',\n",
        "    split='train',\n",
        "    shuffle=True,\n",
        "    batch_size=8,\n",
        "    transforms=[\n",
        "        # Create the model inputs/targets/loss_mask.\n",
        "        gm.data.Seq2SeqTask(\n",
        "            # Select which field from the dataset to use.\n",
        "            # https://www.tensorflow.org/datasets/catalog/mtnt\n",
        "            in_prompt='src',\n",
        "            in_response='dst',\n",
        "            # Output batch is {'input': ..., 'target': ..., 'loss_mask': ...}\n",
        "            out_input='input',\n",
        "            out_target='target',\n",
        "            out_target_mask='loss_mask',\n",
        "            tokenizer=tokenizer,\n",
        "            # Padding parameters\n",
        "            max_length=200,\n",
        "            truncate=True,\n",
        "        ),\n",
        "    ],\n",
        ")\n",
        "\n",
        "ex = ds[0]\n",
        "\n",
        "treescope.show(ex)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3ny2J07G2X7i"
      },
      "source": [
        "We can decode an example from the batch to inspect the model input and check it is properly formatted:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 53,
          "status": "ok",
          "timestamp": 1739286771922,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "Ep2uhBLh07cw",
        "outputId": "806efa9a-88e9-4f13-db14-b902c3f0caf6"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\u003cstart_of_turn\u003euser\n",
            "As far as battle mode, 64 is the best.\u003cend_of_turn\u003e\n",
            "\u003cstart_of_turn\u003emodel\n",
            "En ce qui concerne le mode bataille, 64 est le meilleur.\n"
          ]
        }
      ],
      "source": [
        "text = tokenizer.decode(ex['input'][0])\n",
        "\n",
        "print(text)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "L_ND9CJDlcSy"
      },
      "source": [
        "### Trainer"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S3fXHa_4LnEH"
      },
      "source": [
        "We then create the trainer, reusing the `model`, `init_transform` and `optimizer` created above:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Bv854FDSn7Z-"
      },
      "outputs": [],
      "source": [
        "trainer = kd.train.Trainer(\n",
        "    seed=42,  # The seed of enlightenment\n",
        "    workdir='/tmp/ckpts',  # TODO(epot): Make the workdir optional by default\n",
        "    # Dataset\n",
        "    train_ds=ds,\n",
        "    # Model\n",
        "    model=model,\n",
        "    init_transform=init_transform,\n",
        "    # Training parameters\n",
        "    num_train_steps=500,\n",
        "    train_losses={\n",
        "        \"loss\": kd.losses.SoftmaxCrossEntropyWithIntLabels(\n",
        "            logits=\"preds.logits\",\n",
        "            labels=\"batch.target\",\n",
        "            mask=\"batch.loss_mask\",\n",
        "        ),\n",
        "    },\n",
        "    optimizer=optimizer,\n",
        ")"
      ]
    },
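    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `mask` argument of the loss restricts training to the response tokens. The sketch below shows the idea on toy arrays with plain `optax`. It is not Kauldron's implementation: the array names, the shapes, and the normalization by the number of unmasked tokens are assumptions made for illustration:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import jax.numpy as jnp\n",
        "import optax\n",
        "\n",
        "# Toy batch: 2 sequences, 4 positions, vocabulary of 5 tokens.\n",
        "toy_logits = jnp.zeros((2, 4, 5))  # All-zero logits: a uniform prediction.\n",
        "toy_targets = jnp.array([[1, 2, 3, 0], [4, 0, 0, 0]])\n",
        "# True only on the positions that should contribute to the loss.\n",
        "toy_loss_mask = jnp.array([\n",
        "    [False, True, True, False],\n",
        "    [True, False, False, False],\n",
        "])\n",
        "\n",
        "# Per-token cross-entropy, shape (2, 4).\n",
        "toy_per_token = optax.softmax_cross_entropy_with_integer_labels(\n",
        "    toy_logits, toy_targets\n",
        ")\n",
        "# Zero out the masked positions, then average over the kept ones.\n",
        "toy_masked_loss = jnp.sum(toy_per_token * toy_loss_mask) / jnp.sum(\n",
        "    toy_loss_mask\n",
        ")\n",
        "print(toy_masked_loss)"
      ]
    },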
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xd1RcRekMkRG"
      },
      "source": [
        "Training can be launched with the `.train()` method.\n",
        "\n",
        "Note that the trainer, like the model, is immutable, so it stores neither the state nor the params. Instead, the state containing the trained parameters is returned."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "height": 153,
          "referenced_widgets": [
            "c24f4e4e33314fec95220d30dc122b7e",
            "eff97ec5bfb0423b873e987019c162be",
            "1d2dbb7248334e68b53ef5d85db6e8c2",
            "f75d78091d504c0a98b53697c0fca63c",
            "f7ff1c9852834af2b43853af188258ce",
            "001cd6e73ce1481ba9d534f22b3b3ad3",
            "9a2ab10ffa954bf69d1c2616d7fa3f9c",
            "fee3501cc1f24a1baa24b6fa928edabf",
            "193d3862469140bc90fdef13fe7bbdfb",
            "4d9f4c5bbf4644cb8b5a202606df0e33",
            "bac38805ce4a4bf186510a17c5796334"
          ]
        },
        "executionInfo": {
          "elapsed": 96534,
          "status": "ok",
          "timestamp": 1739287519167,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "xvIDsFPz75GT",
        "outputId": "1d46761b-9b52-4988-9eb3-ea9c58518b71"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Configuring ...\n",
            "Initializing ...\n",
            "Starting training loop at step 0\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "c24f4e4e33314fec95220d30dc122b7e",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "train:   0%|          | 0/501 [00:00\u003c?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "state, aux = trainer.train()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4_VfxhhKRvDd"
      },
      "source": [
        "## Checkpointing"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZcoR4YCon7vX"
      },
      "outputs": [],
      "source": [
        "# TODO(epot): Doc on:\n",
        "# * saving only LoRA params\n",
        "# * Fuse params"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QAbX_hwileZF"
      },
      "source": [
        "## Evaluation\n",
        "\n",
        "Here, we only perform a qualitative evaluation by sampling a response to a single prompt.\n",
        "\n",
        "For more on evaluation:\n",
        "\n",
        "* See the [sampling](https://github.com/google-deepmind/gemma/blob/main/docs/sampling.md) tutorial for details on running inference.\n",
        "* To add evals during training, see the Kauldron [evaluator](https://kauldron.readthedocs.io/en/latest/eval.html) documentation.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fOrkpxlAlf2V"
      },
      "outputs": [],
      "source": [
        "sampler = gm.text.ChatSampler(\n",
        "    model=model,\n",
        "    params=state.params,\n",
        "    tokenizer=tokenizer,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x54YaAteRV94"
      },
      "source": [
        "We test a sentence, using the same formatting used during fine-tuning:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "height": 35
        },
        "executionInfo": {
          "elapsed": 469,
          "status": "ok",
          "timestamp": 1739287741376,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "yM0l9EnPMdHf",
        "outputId": "79a9780f-701d-4b4f-c93e-a818add1218c"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "\"Je me sens heureux aujourd'hui !\""
            ]
          },
          "execution_count": 26,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "sampler.chat('I\\'m feeling happy today!')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sPQUGkR3ZcO_"
      },
      "source": [
        "The model correctly translated our prompt to French!"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "last_runtime": {},
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "001cd6e73ce1481ba9d534f22b3b3ad3": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "193d3862469140bc90fdef13fe7bbdfb": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "1d2dbb7248334e68b53ef5d85db6e8c2": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_fee3501cc1f24a1baa24b6fa928edabf",
            "max": 501,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_193d3862469140bc90fdef13fe7bbdfb",
            "value": 501
          }
        },
        "4d9f4c5bbf4644cb8b5a202606df0e33": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "9a2ab10ffa954bf69d1c2616d7fa3f9c": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "bac38805ce4a4bf186510a17c5796334": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "c24f4e4e33314fec95220d30dc122b7e": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_eff97ec5bfb0423b873e987019c162be",
              "IPY_MODEL_1d2dbb7248334e68b53ef5d85db6e8c2",
              "IPY_MODEL_f75d78091d504c0a98b53697c0fca63c"
            ],
            "layout": "IPY_MODEL_f7ff1c9852834af2b43853af188258ce"
          }
        },
        "eff97ec5bfb0423b873e987019c162be": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_001cd6e73ce1481ba9d534f22b3b3ad3",
            "placeholder": "​",
            "style": "IPY_MODEL_9a2ab10ffa954bf69d1c2616d7fa3f9c",
            "value": "train: 100%"
          }
        },
        "f75d78091d504c0a98b53697c0fca63c": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_4d9f4c5bbf4644cb8b5a202606df0e33",
            "placeholder": "​",
            "style": "IPY_MODEL_bac38805ce4a4bf186510a17c5796334",
            "value": " 501/501 [01:25\u0026lt;00:00, 26.52it/s]"
          }
        },
        "f7ff1c9852834af2b43853af188258ce": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "fee3501cc1f24a1baa24b6fa928edabf": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        }
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
